Matrix Chain Multiplication

Algorithm

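  • Create an upper-triangular partition table for the $n$ matrices and initialize every cell with 0. Cell $(i, j)$ will hold the minimum cost of multiplying matrices $i$ through $j$. The diagonal cells $(i, i)$ stay 0 because a single matrix needs no multiplication.
  • The cost of multiplying a $(p\times q)$ matrix by a $(q\times r)$ matrix is $p\times q\times r$ scalar multiplications.
  • For $k = 1$ to $n - 1$ (the interval length minus one):
    • For every start index $i = 1$ to $n - k$:
      • Split the interval $(i, i + k)$ at every $s$ with $i \le s < i + k$ into $(i, s)$ and $(s + 1, i + k)$. The cost of a split is the cost of $(i, s)$ plus the cost of $(s + 1, i + k)$ plus the cost of multiplying the two resulting products.
      • Store the minimum cost over all splits in cell $(i, i + k)$.
  • The value of cell $(1, n)$ is the answer. In PartitionTable this cell is stored at partition_table[0][0], because each row lists the end indexes from $n$ down to the start index.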
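The steps above amount to the usual recurrence $C(i, i) = 0$ and $C(i, j) = \min_{i \le s < j} \{C(i, s) + C(s + 1, j) + d_{i-1} d_s d_j\}$, where matrix $i$ has shape $d_{i-1}\times d_i$. Here is a minimal sketch of that recurrence, assuming the chain is given as (rows, columns) tuples like the driver cells use. The function name matrix_chain_cost and the dims list are illustrative and not part of the notebook; the sketch uses a plain (n+1)×(n+1) list instead of the shifted half-table, so the answer sits directly at cost[1][n].

```python
from typing import List, Tuple


def matrix_chain_cost(matrices: List[Tuple[int, int]]) -> int:
    """Minimum number of scalar multiplications needed to multiply the chain."""
    if not matrices:
        return 0
    # Adjacent matrices must share their inner dimension.
    for left, right in zip(matrices, matrices[1:]):
        if left[1] != right[0]:
            raise ValueError("Invalid matrix dimensions")

    n = len(matrices)
    # Matrix i (1-based) has shape dims[i - 1] x dims[i].
    dims = [matrices[0][0]] + [cols for _, cols in matrices]

    # cost[i][j] = minimum cost of multiplying matrices i..j; cost[i][i] stays 0.
    cost = [[0] * (n + 1) for _ in range(n + 1)]
    for k in range(1, n):                 # interval length minus one
        for i in range(1, n - k + 1):     # start index of the interval
            j = i + k
            cost[i][j] = min(
                cost[i][s] + cost[s + 1][j] + dims[i - 1] * dims[s] * dims[j]
                for s in range(i, j)      # split into (i, s) and (s + 1, j)
            )
    return cost[1][n]


print(matrix_chain_cost([(10, 20), (20, 30), (30, 40), (40, 50)]))  # 38000
```

Filling the table by increasing $k$ guarantees that every sub-interval a split refers to has already been computed.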

Function to create splits

from typing import List, Tuple

def split_interval(i: int, j: int, verbose: bool = False) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]:
    """Return every split of the interval (i, j) into a left part (i, k) and a right part (k + 1, j)."""
    if verbose:
        print(f"\nSplitting ({i}, {j})")
    splits = []
    for k in range(i, j):
        split = ((i, k), (k + 1, j))
        splits.append(split)
    return splits
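For reference, the same splits can be produced with a single list comprehension. This is a hypothetical compact variant (split_interval_compact is not part of the notebook, and it drops the verbose printing). An interval (i, j) always yields j - i splits, which is why Step 3 of Driver code 1 lists three splits for (1, 4).

```python
from typing import List, Tuple

Interval = Tuple[int, int]  # 1-based, inclusive range of matrices


def split_interval_compact(i: int, j: int) -> List[Tuple[Interval, Interval]]:
    """Hypothetical compact equivalent of split_interval, without verbose output."""
    return [((i, k), (k + 1, j)) for k in range(i, j)]


print(split_interval_compact(1, 3))
# [((1, 1), (2, 3)), ((1, 2), (3, 3))]
```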

Function to get the cost of multiplying two matrices

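def getCost(m1: Tuple[int, int], m2: Tuple[int, int]) -> int:
    """Scalar multiplications needed to multiply an (m1[0] x m1[1]) matrix by an (m2[0] x m2[1]) matrix."""
    # The inner dimensions must agree for the product to exist.
    if m1[1] != m2[0]:
        raise ValueError("Invalid matrix dimensions")

    return m1[0] * m1[1] * m2[1]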
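As a worked example of this formula, and of why the multiplication order matters, take the first three matrices of Driver code 2: $A$ $(5\times 10)$, $B$ $(10\times 3)$ and $C$ $(3\times 12)$. The two possible orders cost:

$$
\begin{aligned}
(AB)C &:\quad 5\cdot 10\cdot 3 + 5\cdot 3\cdot 12 = 150 + 180 = 330\\
A(BC) &:\quad 10\cdot 3\cdot 12 + 5\cdot 10\cdot 12 = 360 + 600 = 960
\end{aligned}
$$

These are the two split costs the solver reports for (1, 3) in Step 2 of that run.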

Function to display the partition table

from typing import List
from IPython.display import display, HTML

def displayTable(arr: List[List[int]]):
    """Render the partition table as HTML: rows are labelled with the start index x,
    columns with the end index y (from n down to 1)."""
    n = len(arr[0])
    headers = f'''<thead><tr><td>&nbsp;</td>{''.join([f"<td>{n - i}</td>" for i in range(n)])}</tr></thead>'''
    htmlData = "<tbody>"
    for i in range(n):
        data = f"<tr><td>{i + 1}</td>"
        for cell in arr[i]:
            # Show empty cells as a non-breaking space so the table keeps its shape.
            t = "&nbsp;" if cell == '' else cell
            data += f"<td>{t}</td>"
        data += "</tr>"
        htmlData += data
    htmlData += "</tbody>"
    display(HTML(f"<table>{headers}{htmlData}</table><hr/>"))

Partition table class

inf = float("inf")

class PartitionTable:
    """Bottom-up partition table for the matrix chain multiplication problem.

    Intervals (x, y) are 1-based and inclusive. Row x - 1 of the table stores
    the intervals that start at x, with columns running from y = n down to y = x.
    """

    def __init__(self, matrices_: List[Tuple[int, int]]):
        self.matrices = matrices_[:]
        self.table_size = len(matrices_)

        # Row i holds (table_size - i) cells, all initialized to 0.
        self.partition_table = [[0] * (self.table_size - i) for i in range(self.table_size)]

        # Map a 1-based interval (x, y) to its (row, column) position in the table.
        self.map_index = lambda x, y: (x - 1, self.table_size - y)

    # setter function
    def setValue(self, x_: int, y_: int, val_: int):
        x, y = self.map_index(x_, y_)
        self.partition_table[x][y] = val_

    # getter function
    def getValue(self, x_: int, y_: int) -> int:
        x, y = self.map_index(x_, y_)
        return self.partition_table[x][y]

    def getCompositeValue(self, partition: Tuple[Tuple[int, int], Tuple[int, int]]) -> int:
        """Cost of one split: cost of the left part + cost of the right part
        + cost of multiplying the two resulting products."""
        cell1, cell2 = partition
        # The product of matrices a..b has shape (rows of matrix a, columns of matrix b).
        matrix1 = (self.matrices[cell1[0] - 1][0], self.matrices[cell1[1] - 1][1])
        matrix2 = (self.matrices[cell2[0] - 1][0], self.matrices[cell2[1] - 1][1])
        cost = getCost(matrix1, matrix2)
        value1 = self.getValue(cell1[0], cell1[1])
        value2 = self.getValue(cell2[0], cell2[1])
        return value1 + value2 + cost

    def createValue(self, x_: int, y_: int, verbose: bool = False):
        """Fill cell (x_, y_) with the minimum cost over all of its splits."""
        # Skip invalid intervals and cells that have already been filled.
        if x_ > y_ or self.getValue(x_, y_) != 0:
            return

        # A single matrix needs no multiplication, so its cell stays 0.
        if x_ == y_:
            return

        currentValue = inf
        for split in split_interval(x_, y_, verbose):
            value = self.getCompositeValue(split)
            if verbose:
                print(f"Cost of split {split[0]} - {split[1]} = {value}")
            currentValue = min(currentValue, value)

        if verbose:
            print(f"Min. cost of ({x_}, {y_}) = {currentValue}")
        self.setValue(x_, y_, currentValue)

    def __call__(self, verbose: bool = False) -> int:
        # k is the interval length minus one: first all pairs, then triples, and so on.
        for k in range(1, self.table_size):
            if verbose:
                display(HTML(f"<h2>Step:- {k}</h2>"))
            # Only start indexes whose interval (i, i + k) fits inside the chain.
            for i in range(1, self.table_size - k + 1):
                self.createValue(i, i + k, verbose)
            if verbose:
                displayTable(self.partition_table)

        # partition_table[0][0] is cell (1, n): the minimum cost of the whole chain.
        return self.partition_table[0][0]
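The index mapping is the least obvious part of PartitionTable. This small sketch reuses the same formula as map_index, (x, y) -> (x - 1, n - y), for n = 4, and prints which interval each stored cell holds. It shows why the whole-chain interval (1, n) ends up at partition_table[0][0] and why the printed tables list the columns from 4 down to 1. The standalone map_index function here is an illustrative copy, not the class attribute itself.

```python
from typing import Tuple

n = 4  # number of matrices, as in Driver code 1


def map_index(x: int, y: int) -> Tuple[int, int]:
    """Same formula as PartitionTable.map_index: interval (x, y) -> (row, column)."""
    return x - 1, n - y


# Row x - 1 holds the intervals starting at x, with end indexes from n down to x.
for x in range(1, n + 1):
    cells = [f"({x}, {y})" for y in range(n, x - 1, -1)]
    print(f"row {x - 1}: " + "  ".join(cells))

# row 0: (1, 4)  (1, 3)  (1, 2)  (1, 1)
# row 1: (2, 4)  (2, 3)  (2, 2)
# row 2: (3, 4)  (3, 3)
# row 3: (4, 4)
```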

Driver code 1

matrices = [(10, 20), (20, 30), (30, 40), (40, 50)]

solver = PartitionTable(matrices)
print(f"Answer = {solver(verbose=True)}")

Step:- 1

Splitting (1, 2)
Cost of split (1, 1) - (2, 2) = 6000
Min. cost of (1, 2) = 6000

Splitting (2, 3)
Cost of split (2, 2) - (3, 3) = 24000
Min. cost of (2, 3) = 24000

Splitting (3, 4)
Cost of split (3, 3) - (4, 4) = 60000
Min. cost of (3, 4) = 60000
i\j      4      3      2      1
1        0      0   6000      0
2        0  24000      0
3    60000      0
4        0

Step:- 2

Splitting (1, 3)
Cost of split (1, 1) - (2, 3) = 32000
Cost of split (1, 2) - (3, 3) = 18000
Min. cost of (1, 3) = 18000

Splitting (2, 4)
Cost of split (2, 2) - (3, 4) = 90000
Cost of split (2, 3) - (4, 4) = 64000
Min. cost of (2, 4) = 64000
i\j      4      3      2      1
1        0  18000   6000      0
2    64000  24000      0
3    60000      0
4        0

Step:- 3

Splitting (1, 4)
Cost of split (1, 1) - (2, 4) = 74000
Cost of split (1, 2) - (3, 4) = 81000
Cost of split (1, 3) - (4, 4) = 38000
Min. cost of (1, 4) = 38000
i\j      4      3      2      1
1    38000  18000   6000      0
2    64000  24000      0
3    60000      0
4        0

Answer = 38000
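Reading the winning splits back from this log (the notebook prints costs, not the parenthesization): (1, 4) is best split as (1, 3) - (4, 4), and (1, 3) as (1, 2) - (3, 3), so the optimal order is $((M_1M_2)M_3)M_4$, where $M_i$ is the $i$-th matrix. Its cost adds up as:

$$
\underbrace{10\cdot 20\cdot 30}_{M_1M_2} + \underbrace{10\cdot 30\cdot 40}_{(M_1M_2)M_3} + \underbrace{10\cdot 40\cdot 50}_{((M_1M_2)M_3)M_4} = 6000 + 12000 + 20000 = 38000
$$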

Driver code 2

matrices = [(5, 10), (10, 3), (3, 12), (12, 5)]

solver = PartitionTable(matrices)
print(f"Answer = {solver(verbose=True)}")

Step:- 1

Splitting (1, 2)
Cost of split (1, 1) - (2, 2) = 150
Min. cost of (1, 2) = 150

Splitting (2, 3)
Cost of split (2, 2) - (3, 3) = 360
Min. cost of (2, 3) = 360

Splitting (3, 4)
Cost of split (3, 3) - (4, 4) = 180
Min. cost of (3, 4) = 180
i\j    4    3    2    1
1      0    0  150    0
2      0  360    0
3    180    0
4      0

Step:- 2

Splitting (1, 3)
Cost of split (1, 1) - (2, 3) = 960
Cost of split (1, 2) - (3, 3) = 330
Min. cost of (1, 3) = 330

Splitting (2, 4)
Cost of split (2, 2) - (3, 4) = 330
Cost of split (2, 3) - (4, 4) = 960
Min. cost of (2, 4) = 330
i\j    4    3    2    1
1      0  330  150    0
2    330  360    0
3    180    0
4      0

Step:- 3

Splitting (1, 4)
Cost of split (1, 1) - (2, 4) = 580
Cost of split (1, 2) - (3, 4) = 405
Cost of split (1, 3) - (4, 4) = 630
Min. cost of (1, 4) = 405
i\j    4    3    2    1
1    405  330  150    0
2    330  360    0
3    180    0
4      0

Answer = 405
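Here the best split of (1, 4) is (1, 2) - (3, 4), so the log corresponds to the order $(M_1M_2)(M_3M_4)$, which costs:

$$
\underbrace{5\cdot 10\cdot 3}_{M_1M_2} + \underbrace{3\cdot 12\cdot 5}_{M_3M_4} + \underbrace{5\cdot 3\cdot 5}_{(M_1M_2)(M_3M_4)} = 150 + 180 + 75 = 405
$$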

Driver code 3

matrices = [(5, 4), (4, 6), (6, 2), (2, 7), (7, 3)]

solver = PartitionTable(matrices)
print(f"Answer = {solver(verbose=True)}")

Step:- 1

Splitting (1, 2)
Cost of split (1, 1) - (2, 2) = 120
Min. cost of (1, 2) = 120

Splitting (2, 3)
Cost of split (2, 2) - (3, 3) = 48
Min. cost of (2, 3) = 48

Splitting (3, 4)
Cost of split (3, 3) - (4, 4) = 84
Min. cost of (3, 4) = 84

Splitting (4, 5)
Cost of split (4, 4) - (5, 5) = 42
Min. cost of (4, 5) = 42
i\j    5    4    3    2    1
1      0    0    0  120    0
2      0    0   48    0
3      0   84    0
4     42    0
5      0

Step:- 2

Splitting (1, 3)
Cost of split (1, 1) - (2, 3) = 88
Cost of split (1, 2) - (3, 3) = 180
Min. cost of (1, 3) = 88

Splitting (2, 4)
Cost of split (2, 2) - (3, 4) = 252
Cost of split (2, 3) - (4, 4) = 104
Min. cost of (2, 4) = 104

Splitting (3, 5)
Cost of split (3, 3) - (4, 5) = 78
Cost of split (3, 4) - (5, 5) = 210
Min. cost of (3, 5) = 78
i\j    5    4    3    2    1
1      0    0   88  120    0
2      0  104   48    0
3     78   84    0
4     42    0
5      0

Step:- 3

Splitting (1, 4)
Cost of split (1, 1) - (2, 4) = 244
Cost of split (1, 2) - (3, 4) = 414
Cost of split (1, 3) - (4, 4) = 158
Min. cost of (1, 4) = 158

Splitting (2, 5)
Cost of split (2, 2) - (3, 5) = 150
Cost of split (2, 3) - (4, 5) = 114
Cost of split (2, 4) - (5, 5) = 188
Min. cost of (2, 5) = 114
i\j    5    4    3    2    1
1      0  158   88  120    0
2    114  104   48    0
3     78   84    0
4     42    0
5      0

Step:- 4

Splitting (1, 5)
Cost of split (1, 1) - (2, 5) = 174
Cost of split (1, 2) - (3, 5) = 288
Cost of split (1, 3) - (4, 5) = 160
Cost of split (1, 4) - (5, 5) = 263
Min. cost of (1, 5) = 160
i\j    5    4    3    2    1
1    160  158   88  120    0
2    114  104   48    0
3     78   84    0
4     42    0
5      0

Answer = 160
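Tracing the winning splits again: (1, 5) splits best as (1, 3) - (4, 5), and (1, 3) as (1, 1) - (2, 3), giving the order $(M_1(M_2M_3))(M_4M_5)$. The four products it performs cost:

$$
\underbrace{4\cdot 6\cdot 2}_{M_2M_3} + \underbrace{5\cdot 4\cdot 2}_{M_1(M_2M_3)} + \underbrace{2\cdot 7\cdot 3}_{M_4M_5} + \underbrace{5\cdot 2\cdot 3}_{\text{final product}} = 48 + 40 + 42 + 30 = 160
$$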